In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
from PIL import Image
from torchvision import transforms
import torchvision.utils as vutils
import numpy as np
import torch.nn.functional as F
from math import log2
import math
In [ ]:
# Root directory for dataset
dataroot = "C:\.Python Projects\AnimeGAN\data"

# Number of workers for dataloader
workers = 14

# Set True to prevent releasing and reassigning workers between epochs
persistent_workers = True

# Number of batches to prefetch per worker ( workers * prefect_factor = Number of Batches preloaded )
prefetch_factor = 4

# Batch size during training
batch_sizes = [64, 64, 64, 64, 40, 16, 8, 4] # 512: [64, 64, 64, 64, 10, 6, 4, 2], 256: [64, 64, 64, 64, 32, 16, 8, 4]

# Image size to train up to (inclusive). Paper uses 1024
image_size = 128

assert image_size >= 4, f"{image_size} is not greater than or equal to 4!"
assert image_size <= 1024, f"{image_size} is not less than or equal to 1024!"
assert math.ceil(log2(image_size)) == math.floor(log2(image_size)), f"{image_size} is not a power of 2!"

# Image size to start training at
start_train_at = 4

assert image_size >= 4, f"{start_train_at} is not greater than or equal to 4!"
assert start_train_at <= image_size, f"{start_train_at} is not less than or equal to {image_size}"
assert math.ceil(log2(start_train_at)) == math.floor(log2(start_train_at)), f"{start_train_at} is not a power of 2!"

# Device to push to
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
pin_memory = True if device.type == "cuda" else False

# Size of latent vector, Paper uses 512 for z_dim and in_channels
z_dim = 256

in_channels = 256

# Learning rate for optimizers. Paper uses 1e-3
lr = 1e-3

lambda_gp = 10

# Number of steps to reach desired image size
num_steps = int(log2(image_size / 4)) + 1

# Progressive Epochs
#   4x4: Train the model to see 800k images. 
#   8x8, 16x16, ..., img_size x img_size: Train model to see 800k images fading in the new layer and 800k images to stabilize.
#   
#   4x4 epochs: 800,000 / dataset_size
#   onwards... 2 times 4x4 epochs

prog_epochs = [32] + [64] * (num_steps - 1)

# Used to create the layers progressively.
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]

# Fixed noise to monitor progression of model
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)

# Display images after each epoch
display_images = False

# Use pretrained model
use_pretrained = False

# Set true to show every step with tqdm
update_last = True

start_epoch = 0

step = int(log2(start_train_at / 4))

if use_pretrained:
    path = "../PATH_TO_CHECKPOINT.pth"
    checkpoint = torch.load(path)
    batch_sizes = checkpoint["batch_sizes"]
    start_train_at = checkpoint["start_training_at"]
    fixed_noise = checkpoint["fixed_noise"]
    z_dim = checkpoint["z_dim"]
    in_channels = checkpoint["in_channels"]
    step = int(log2(start_train_at / 4))
    start_epoch = checkpoint["epoch"]
In [3]:
class WSConv2d(nn.Module):
    def __init__(self, input_channel, out_channel, kernel_size=3, stride=1, padding=1, gain=2):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(input_channel, out_channel, kernel_size, stride, padding)
        self.scale = (gain / (input_channel * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
In [4]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt( torch.mean( x ** 2, dim=1, keepdim=True ) + self.epsilon )
In [5]:
class ConvBlock(nn.Module):
    def __init__(self, input_channel, out_channel, pixel_norm=True):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(input_channel, out_channel)
        self.conv2 = WSConv2d(out_channel, out_channel)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        self.use_pn = pixel_norm

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x
In [6]:
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0), # 1x1 -> 4x4
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm()
        )

        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([self.initial_rgb])

        for i in range(len(factors) - 1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i+1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        out = self.initial(x)
        
        if steps == 0:
            return self.initial_rgb(out)
        
        for i in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[i](upscaled)
        
        final_upscaled = self.rgb_layers[steps-1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
In [7]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        for i in range(len(factors) - 1, 0, -1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i-1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, pixel_norm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in_c, kernel_size=1, stride=1, padding=0))

        # 4x4 img res
        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
        )

    def fade_in(self, alpha, downscaled, out):
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x: torch.Tensor):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        
        cur_step = len(self.prog_blocks) - steps

        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e, image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
In [8]:
gen = Generator(z_dim=z_dim, in_channels=in_channels)
disc = Discriminator(in_channels=in_channels)

for img_size in [4, 8, 16, 32, 64, 128, 256, 512]:
    num_steps = int(log2(img_size / 4))
    x = torch.randn((1, z_dim, 1, 1))
    z = gen(x, 0.5, steps=num_steps)
    assert z.shape == (1, 3, img_size, img_size)
    out = disc(z, alpha=0.5, steps=num_steps)
    print(f"Success! at img size: {img_size}")
C:\Users\palge\AppData\Local\Temp\ipykernel_14344\2363565070.py:31: UserWarning: std(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel). (Triggered internally at ..\aten\src\ATen\native\ReduceOps.cpp:1760.)
  torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
Success! at img size: 4
Success! at img size: 8
Success! at img size: 16
Success! at img size: 32
Success! at img size: 64
Success! at img size: 128
Success! at img size: 256
Success! at img size: 512
In [9]:
def get_loader(image_size):
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # Resize to a standard size (can be adjusted)
        transforms.ToTensor(),  # Convert the image to a PyTorch tensor
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
    ])

    batch_size = batch_sizes[int(log2(image_size / 4))]

    dataset = dset.ImageFolder(root=dataroot, transform=transform)

    dataloader = DataLoader(dataset, 
                            batch_size=batch_size, 
                            shuffle=True, 
                            num_workers=workers, 
                            drop_last=True, 
                            pin_memory=pin_memory, 
                            persistent_workers=persistent_workers, 
                            prefetch_factor=prefetch_factor)

    return dataloader, dataset
In [10]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
In [11]:
from tqdm import tqdm

def train(disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen, epoch, num_epochs):
    loop = tqdm(loader, total=len(loader), miniters=50, desc=f"Epoch [{epoch + 1}/{num_epochs}]", leave=update_last)
    for real, _ in loop:
        #if batch_idx == 10:
        #    break

        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Generate noise and fake data
        noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
        fake = gen(noise, alpha, step)

        # Train Discriminator
        disc_real = disc(real, alpha, step)
        disc_fake = disc(fake.detach(), alpha, step)
        gp = gradient_penalty(disc, real, fake, alpha, step, device)
        loss_disc = (
            -(torch.mean(disc_real) - torch.mean(disc_fake))
            + lambda_gp * gp
            + (0.001 * torch.mean(disc_real ** 2))
        )

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator
        gen_fake = disc(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Update alpha for progressive growing
        alpha += cur_batch_size / ((prog_epochs[step] * 0.5) * len(dataset))

        alpha = min(alpha, 1)
    
    return alpha
In [12]:
import matplotlib.pyplot as plt

gen = Generator(z_dim=z_dim, in_channels=in_channels).to(device)
disc = Discriminator(in_channels=in_channels).to(device)

optimizer_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.99))
optimizer_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.0, 0.99))

if use_pretrained:
    gen.load_state_dict(checkpoint['gen_state'])
    disc.load_state_dict(checkpoint['disc_state'])
    optimizer_gen.load_state_dict(checkpoint['gen_optim'])
    optimizer_disc.load_state_dict(checkpoint['disc_optim'])

fakes = []

if use_pretrained:
    fakes = checkpoint["fakes"]
In [13]:
gen.train()
disc.train()

change_alpha = use_pretrained

for num_epochs in prog_epochs[step:]:
    if start_train_at > image_size:
        break
    alpha = 1e-5
    if change_alpha:
        alpha = checkpoint['alpha']
        change_alpha = False

    loader, dataset = get_loader(4 * 2 ** step)
    
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(start_epoch, num_epochs):

        alpha = train(
            disc,
            gen,
            loader,
            dataset,
            step,
            alpha,
            optimizer_disc,
            optimizer_gen,
            epoch,
            num_epochs,
        )

        with torch.no_grad():
            img = gen(fixed_noise, alpha, step) * 0.5 + 0.5

        fakes.append(img)

        if display_images:
            img = img.cpu().detach()
            img = (img - img.min()) / (img.max() - img.min())

            # Display each image
            fig, axes = plt.subplots(8, 8, figsize=(8, 8))  # Create a 8x8 grid
            axes = axes.flatten()

            for i, ax in enumerate(axes):
                image = img[i].permute(1, 2, 0).numpy()  # Rearrange dimensions to (H, W, C)
                ax.imshow(image)
                ax.axis('off')
            plt.tight_layout()
            plt.show()

        if epoch + 1 != num_epochs:
            torch.save({
                'batch_sizes': batch_sizes,
                'start_training_at': 4 * 2 ** step,
                'alpha': alpha,
                'fixed_noise': fixed_noise,
                'z_dim': z_dim,
                'in_channels': in_channels,
                'epoch': epoch + 1,
                'fakes': fakes,
                'gen_state': gen.state_dict(),
                'disc_state': disc.state_dict(),
                'gen_optim': optimizer_gen.state_dict(),
                'disc_optim': optimizer_disc.state_dict(),
            }, f"models/training_imgsize_{4 * 2 ** (step)}_zdim_{z_dim}_progression.pth")

    torch.save({
        'batch_sizes': batch_sizes,
        'start_training_at': (4 * 2 ** step) * 2,
        'alpha': 1e-5,
        'fixed_noise': fixed_noise,
        'z_dim': z_dim,
        'in_channels': in_channels,
        'epoch': 0,
        'fakes': fakes,
        'gen_state': gen.state_dict(),
        'disc_state': disc.state_dict(),
        'gen_optim': optimizer_gen.state_dict(),
        'disc_optim': optimizer_disc.state_dict(),
    }, f"models/pretrained_imgsize_{4 * 2 ** step}_zdim_{z_dim}.pth")

    step += 1

gen.eval()
disc.eval()

print("Eval mode activated")
Current image size: 64
Epoch [25/64]: 100%|██████████| 641/641 [08:46<00:00,  1.22it/s]
Epoch [26/64]: 100%|██████████| 641/641 [08:00<00:00,  1.33it/s]
Epoch [27/64]: 100%|██████████| 641/641 [07:12<00:00,  1.48it/s]
Epoch [28/64]: 100%|██████████| 641/641 [08:01<00:00,  1.33it/s]
Epoch [29/64]: 100%|██████████| 641/641 [08:00<00:00,  1.33it/s]
Epoch [30/64]: 100%|██████████| 641/641 [07:28<00:00,  1.43it/s]
Epoch [31/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [32/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [33/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [34/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [35/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [36/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [37/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [38/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [39/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [40/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [41/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [42/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [43/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [44/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [45/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [46/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [47/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [48/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [49/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [50/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [51/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [52/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [53/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [54/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [55/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [56/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [57/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [58/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [59/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [60/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [61/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [62/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [63/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Epoch [64/64]: 100%|██████████| 641/641 [07:11<00:00,  1.48it/s]
Current image size: 128
Epoch [25/64]: 100%|██████████| 1283/1283 [19:25<00:00,  1.10it/s]
Epoch [26/64]: 100%|██████████| 1283/1283 [18:45<00:00,  1.14it/s]
Epoch [27/64]: 100%|██████████| 1283/1283 [18:45<00:00,  1.14it/s]
Epoch [28/64]: 100%|██████████| 1283/1283 [18:46<00:00,  1.14it/s]
Epoch [29/64]: 100%|██████████| 1283/1283 [18:48<00:00,  1.14it/s]
Epoch [30/64]: 100%|██████████| 1283/1283 [18:45<00:00,  1.14it/s]
Epoch [31/64]: 100%|██████████| 1283/1283 [18:44<00:00,  1.14it/s]
Epoch [32/64]: 100%|██████████| 1283/1283 [18:43<00:00,  1.14it/s]
Epoch [33/64]: 100%|██████████| 1283/1283 [18:41<00:00,  1.14it/s]
Epoch [34/64]: 100%|██████████| 1283/1283 [18:43<00:00,  1.14it/s]
Epoch [35/64]: 100%|██████████| 1283/1283 [18:45<00:00,  1.14it/s]
Epoch [36/64]: 100%|██████████| 1283/1283 [18:46<00:00,  1.14it/s]
Epoch [37/64]: 100%|██████████| 1283/1283 [18:46<00:00,  1.14it/s]
Epoch [38/64]: 100%|██████████| 1283/1283 [18:45<00:00,  1.14it/s]
Epoch [39/64]: 100%|██████████| 1283/1283 [20:08<00:00,  1.06it/s]
Epoch [40/64]: 100%|██████████| 1283/1283 [18:47<00:00,  1.14it/s]
Epoch [41/64]: 100%|██████████| 1283/1283 [18:46<00:00,  1.14it/s]
Epoch [42/64]: 100%|██████████| 1283/1283 [20:09<00:00,  1.06it/s]
Epoch [43/64]: 100%|██████████| 1283/1283 [18:47<00:00,  1.14it/s]
Epoch [44/64]: 100%|██████████| 1283/1283 [19:44<00:00,  1.08it/s]
Epoch [45/64]: 100%|██████████| 1283/1283 [19:09<00:00,  1.12it/s]
Epoch [46/64]: 100%|██████████| 1283/1283 [20:08<00:00,  1.06it/s]
Epoch [47/64]: 100%|██████████| 1283/1283 [19:24<00:00,  1.10it/s]
Epoch [48/64]: 100%|██████████| 1283/1283 [18:42<00:00,  1.14it/s]
Epoch [49/64]: 100%|██████████| 1283/1283 [18:43<00:00,  1.14it/s]
Epoch [50/64]: 100%|██████████| 1283/1283 [18:43<00:00,  1.14it/s]
Epoch [51/64]: 100%|██████████| 1283/1283 [18:45<00:00,  1.14it/s]
Epoch [52/64]: 100%|██████████| 1283/1283 [19:14<00:00,  1.11it/s]
Epoch [53/64]: 100%|██████████| 1283/1283 [21:56<00:00,  1.03s/it]
Epoch [54/64]: 100%|██████████| 1283/1283 [22:46<00:00,  1.07s/it]
Epoch [55/64]: 100%|██████████| 1283/1283 [21:02<00:00,  1.02it/s]
Epoch [56/64]: 100%|██████████| 1283/1283 [21:32<00:00,  1.01s/it]
Epoch [57/64]: 100%|██████████| 1283/1283 [20:41<00:00,  1.03it/s]
Epoch [58/64]: 100%|██████████| 1283/1283 [19:58<00:00,  1.07it/s]
Epoch [59/64]: 100%|██████████| 1283/1283 [20:03<00:00,  1.07it/s]
Epoch [60/64]: 100%|██████████| 1283/1283 [19:58<00:00,  1.07it/s]
Epoch [61/64]: 100%|██████████| 1283/1283 [19:50<00:00,  1.08it/s]
Epoch [62/64]: 100%|██████████| 1283/1283 [20:05<00:00,  1.06it/s]
Epoch [63/64]: 100%|██████████| 1283/1283 [19:54<00:00,  1.07it/s]
Epoch [64/64]: 100%|██████████| 1283/1283 [19:49<00:00,  1.08it/s]
Eval mode activated
In [65]:
img_list = []

for x in [x.detach().cpu() for x in fakes]: # To shorten number of images: [x[:8] for x in fakes]
    x = torch.nn.functional.interpolate(x, size=(128, 128), mode="nearest")
    img_list.append(vutils.make_grid(x, padding=2, normalize=True))
In [66]:
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

ani.save("anime.gif", writer='pillow', fps=10)

HTML(ani.to_jshtml())
Animation size has reached 21050564 bytes, exceeding the limit of 20971520.0. If you're sure you want a larger animation embedded, set the animation.embed_limit rc parameter to a larger value (in MB). This and further frames will be dropped.
Out[66]:
No description has been provided for this image
No description has been provided for this image
In [24]:
# Grab a batch of real images from the dataloader
dataloader, dataset = get_loader(image_size=128)
real_batch = torch.stack([dataset[i][0] for i in range(64)]).to(device)
#real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch, padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
No description has been provided for this image
In [737]:
gen.eval()
disc.eval()
print("Eval mode set.")
Eval mode set.
In [ ]:
from torchvision.transforms.functional import to_pil_image

noise = torch.randn(1, z_dim, 1, 1).to(device)

output: torch.Tensor = gen(noise, 1, step - 1) * 0.5 + 0.5
output = output.squeeze(0)
img = to_pil_image(output)

plt.figure(figsize=(2,2))
plt.axis("off")
plt.imshow(img)
plt.show()
No description has been provided for this image
In [ ]:
"""def interpolate_points(p1, p2, n_steps=100):
    ratios = np.linspace(0, 1, num=n_steps)[1:]
    vectors = [p1]
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return torch.stack(vectors, dim=0)"""
In [ ]:
"""point1 = torch.randn(z_dim, 1, 1)
point2 = torch.randn(z_dim, 1, 1)

points = interpolate_points(point1, point2, n_steps=100)

images = [to_pil_image((gen(p.to(device), 1, step - 1) * 0.5 + 0.5).squeeze(0)) for p in points]
images[0].save("progress.gif", save_all=True, append_images=images[1:], duration=10, loop=2)"""
In [ ]:
"""
start_train_at = 4 * 2 ** step
epoch = 0 # Specify the epoch to resume training at

torch.save({
    'batch_sizes': batch_sizes,
    'start_training_at': start_train_at,
    'fixed_noise': fixed_noise,
    'z_dim': z_dim,
    'in_channels': in_channels,
    'factors': factors,
    'epoch': epoch,
    'fakes': fakes,
    'gen_state': gen.state_dict(),
    'disc_state': disc.state_dict(),
    'gen_optim': optimizer_gen.state_dict(),
    'disc_optim': optimizer_disc.state_dict(),
}, f"model_imgsize_{4 * 2 ** step}_continue_epoch_{epoch}.pth")
"""